First order: Harmonise results a bit more.
# NOTE(review): `rm(list = ls())` removed -- clearing the workspace inside a
# script is an anti-pattern: it does not give a clean session (packages,
# options, etc. survive) and silently deletes the caller's objects. Restart R
# for a truly fresh state. Nothing below depends on the removed call.
seed <- 1909
# loading & modifying data
library("readr")        # to read the data
library("dplyr")        # to manipulate data
library("fastDummies")  # create dummies
# charts & tables
library("ggplot2")      # to create charts
library("patchwork")    # to combine charts
library("flextable")    # design tables
library("modelsummary") # structure tables
library("kableExtra")   # design table
library("estimatr")     # robust SEs for linear models
library("ggpubr")       # arrange ggplots (ggarrange)
# regression & analysis
library("fixest")       # high dimensional FE
library("skimr")        # skim the data
# machine learning
library("policytree")   # policy tree (Athey & Wager, 2021)
library("grf")          # causal forest
library("rsample")      # data splitting
library("randomForest") # Traditional Random Forests
library("mlr3")         # learners
library("mlr3learners") # learners
library("gbm")          # Generalized Boosted Regression
library("DoubleML")     # Double ML
# load full dataset (replication sample): keep pre-2004 crashes,
# drop the imputed columns ("imp*")
df_repl<-read_delim("../data/FARS-data-full-sample.txt",delim = "\t")%>%
filter(year<2004)%>%
select(-starts_with("imp"))
# load small dataset (selection/causal sample): same filters
df_sel<-read_delim("../data/FARS-data-selection-sample.txt",delim = "\t")%>%
filter(year<2004)%>%
select(-starts_with("imp"))
# remove rows with missing cases (listwise deletion on all columns)
df_repl<-df_repl[complete.cases(df_repl), ]
df_sel<-df_sel[complete.cases(df_sel), ]
# print number of obs
print(paste('Number of observations in the data:',nrow(df_repl),' (full sample);',nrow(df_sel), ' (selected/causal sample)'))
## [1] "Number of observations in the data: 38455 (full sample); 10328 (selected/causal sample)"
# Treatment indicators ----
# The exact same preprocessing was duplicated for df_repl and df_sel;
# factor it into one helper so the two samples cannot drift apart.
prep_treatment_data <- function(df) {
  # Multi-valued treatment D (restraint type; "NONE" = unrestrained) plus a
  # binary any-restraint indicator Dbinary. case_when is evaluated in order,
  # so lapshould takes precedence over lapbelt and childseat, as before.
  df <- df %>% mutate(
    D = case_when(
      lapshould == 1 ~ "LapShoulderSeat",
      lapbelt == 1 ~ "Lapbelt",
      childseat == 1 ~ "Childseat",
      TRUE ~ "NONE"
    ),
    D = factor(D, levels = c("NONE", "Lapbelt", "LapShoulderSeat", "Childseat")),
    Dbinary = case_when(lapshould == 1 ~ 1, lapbelt == 1 ~ 1, childseat == 1 ~ 1, TRUE ~ 0)
  )
  # Convert categorical columns to indicator dummies; drop the raw `restraint`
  # column, the dummies fastDummies builds from D itself, and `crashtm`.
  df <- dummy_cols(df %>% select(-restraint)) %>%
    select(-starts_with("D_"), -crashtm)
  # Keep only the analysis variables.
  df %>% select(
    splmU55, thoulbs_I, numcrash, weekend, lowviol, highviol,
    ruralrd, frimp, suv, death, D, Dbinary, modelyr, age, year
  )
}
df_repl <- prep_treatment_data(df_repl)
df_sel <- prep_treatment_data(df_sel)
# Training and test data (50/50 split; seeded for reproducibility)
set.seed(seed)
df_repl_split <- initial_split(df_repl, prop = .5)
df_repl_train <- training(df_repl_split)
df_repl_test <- testing(df_repl_split)
df_sel_split <- initial_split(df_sel, prop = .5)
df_sel_train <- training(df_sel_split)
df_sel_test <- testing(df_sel_split)
# X Matrices (covariates only: outcome and treatment columns removed)
X_repl_train<-as.matrix(df_repl_train%>%select(-death,-D,-Dbinary))
X_repl_test<- as.matrix(df_repl_test%>%select(-death,-D,-Dbinary))
X_sel_train<- as.matrix(df_sel_train%>%select(-death,-D,-Dbinary))
X_sel_test<- as.matrix(df_sel_test%>%select(-death,-D,-Dbinary))
# Intercept-only design matrices for the "no controls" specifications
X_repl_train_nocontrols<-as.matrix(rep(1,nrow(X_repl_train)))
X_repl_test_nocontrols<- as.matrix(rep(1,nrow(X_repl_test)))
X_sel_train_nocontrols<- as.matrix(rep(1,nrow(X_sel_train)))
X_sel_test_nocontrols<- as.matrix(rep(1,nrow(X_sel_test)))
# D matrices
# BUG FIX: D_repl_test and D_sel_test were previously built from the
# *training* data frames (df_repl_train / df_sel_train), so the test-set
# treatment vectors had the wrong rows and the wrong length. They are now
# built from the corresponding test sets.
D_repl_train<-factor(df_repl_train$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_repl_test<-factor(df_repl_test$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_sel_train<-factor(df_sel_train$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_sel_test<-factor(df_sel_test$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_binary_repl_train<-as.matrix(df_repl_train%>%select(Dbinary))
D_binary_repl_test<- as.matrix(df_repl_test%>%select(Dbinary))
D_binary_sel_train<- as.matrix(df_sel_train%>%select(Dbinary))
D_binary_sel_test<- as.matrix(df_sel_test%>%select(Dbinary))
# Y matrices (outcome: death indicator)
Y_repl_train<-as.matrix(df_repl_train%>%select(death))
Y_repl_test<- as.matrix(df_repl_test%>%select(death))
Y_sel_train<- as.matrix(df_sel_train%>%select(death))
Y_sel_test<- as.matrix(df_sel_test%>%select(death))
# Summary-statistics table for the selection sample's analysis variables
tmp <- df_sel%>%select(splmU55,thoulbs_I,modelyr,year,numcrash,weekend,lowviol,highviol,ruralrd,frimp,suv,death)
# remove missing and rescale each column (feeds the inline boxplots/histograms)
tmp_list <- lapply(tmp, na.omit)
tmp_list <- lapply(tmp_list, scale)
# placeholder column renderer: datasummary needs a function here; the actual
# images are injected by column_spec() below
emptycol = function(x) " "
datasummary(splmU55+thoulbs_I+modelyr+year+numcrash+weekend+lowviol+highviol+ruralrd+frimp+suv+death ~ Mean + SD + Heading("Boxplot") * emptycol + Heading("Histogram") * emptycol, data = tmp) %>%
column_spec(column = 4, image = spec_boxplot(tmp_list)) %>%
column_spec(column = 5, image = spec_hist(tmp_list))
| Mean | SD | Boxplot | Histogram | |
|---|---|---|---|---|
| splmU55 | 0.88 | 0.33 | ||
| thoulbs_I | 2.45 | 1.54 | ||
| modelyr | 1987.09 | 8.30 | ||
| year | 1993.42 | 7.13 | ||
| numcrash | 6.63 | 4.51 | ||
| weekend | 0.40 | 0.49 | ||
| lowviol | 0.29 | 0.45 | ||
| highviol | 0.08 | 0.27 | ||
| ruralrd | 0.08 | 0.28 | ||
| frimp | 0.67 | 0.47 | ||
| suv | 0.10 | 0.29 | ||
| death | 0.04 | 0.20 |
The next cell initialises the DML data objects and learners, then fits the model twice: first without controls, then with the full set of controls.
# Create DML data objects: outcome = death, treatment = Dbinary; one with an
# intercept-only X (no controls) and one with the full covariate matrix
dml_data_nocontrols = double_ml_data_from_matrix(y=Y_repl_train,d=D_binary_repl_train,X_repl_train_nocontrols)
dml_data_controls = double_ml_data_from_matrix(y=Y_repl_train,d=D_binary_repl_train,X_repl_train)
# Initiate learners: xgboost classifier for the propensity nuisance (ml_m),
# xgboost regressor for the outcome nuisance (ml_g)
lgr::get_logger("mlr3")$set_threshold("warn")
learner=lrn(eval_metric="logloss","classif.xgboost")
ml_m = learner$clone()
learner=lrn(objective ='reg:squarederror',"regr.xgboost")
ml_g = learner$clone()
# Estimate DML partially linear regression without controls
obj_dml = DoubleMLPLR$new(dml_data_nocontrols, ml_g=ml_g, ml_m=ml_m)
obj_dml$fit()
print("------------- No controls ------------- ")
## [1] "------------- No controls ------------- "
print(obj_dml)
## ================= DoubleMLPLR Object ==================
##
##
## ------------------ Data summary ------------------
## Outcome variable: y
## Treatment variable(s): d
## Covariates: X1
## Instrument(s):
## No. Observations: 19227
##
## ------------------ Score & algorithm ------------------
## Score function: partialling out
## DML algorithm: dml2
##
## ------------------ Machine learner ------------------
## ml_g: regr.xgboost
## ml_m: classif.xgboost
##
## ------------------ Resampling ------------------
## No. folds: 5
## No. repeated sample splits: 1
## Apply cross-fitting: TRUE
##
## ------------------ Fit summary ------------------
## Estimates and significance testing of the effect of target variables
## Estimate. Std. Error t value Pr(>|t|)
## d -0.100617 0.006553 -15.35 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
# Estimate DML with the full set of controls (same learners as above;
# `obj_dml` is intentionally overwritten)
obj_dml = DoubleMLPLR$new(dml_data_controls, ml_g=ml_g, ml_m=ml_m)
obj_dml$fit()
cat("\n\n\n")
print("------------- With controls ------------- ")
## [1] "------------- With controls ------------- "
print(obj_dml)
## ================= DoubleMLPLR Object ==================
##
##
## ------------------ Data summary ------------------
## Outcome variable: y
## Treatment variable(s): d
## Covariates: X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X12
## Instrument(s):
## No. Observations: 19227
##
## ------------------ Score & algorithm ------------------
## Score function: partialling out
## DML algorithm: dml2
##
## ------------------ Machine learner ------------------
## ml_g: regr.xgboost
## ml_m: classif.xgboost
##
## ------------------ Resampling ------------------
## No. folds: 5
## No. repeated sample splits: 1
## Apply cross-fitting: TRUE
##
## ------------------ Fit summary ------------------
## Estimates and significance testing of the effect of target variables
## Estimate. Std. Error t value Pr(>|t|)
## d -0.101753 0.006971 -14.6 <2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
print("------------- No controls ------------- ")
## [1] "------------- No controls ------------- "
# Multi-arm causal forest (grf) without controls: intercept-only X,
# multi-valued treatment D on the selection training sample
cfnocontrols <- multi_arm_causal_forest(X=X_sel_train_nocontrols, Y=Y_sel_train, W=D_sel_train)
average_treatment_effect(cfnocontrols)
## estimate std.err contrast outcome
## Lapbelt - NONE -0.05339767 0.005772780 Lapbelt - NONE death
## LapShoulderSeat - NONE -0.04656707 0.005979402 LapShoulderSeat - NONE death
## Childseat - NONE -0.04718428 0.005736955 Childseat - NONE death
cat("\n\n\n")
print("------------- With controls ------------- ")
## [1] "------------- With controls ------------- "
# Multi-arm causal forest with the full covariate matrix on the replication
# sample. NOTE(review): the ATEs in the recorded output below are NaN --
# presumably an overlap / propensity problem with this specification;
# investigate before relying on this fit.
cfcontrols <- multi_arm_causal_forest(X=X_repl_train, Y=Y_repl_train, W=D_repl_train)
average_treatment_effect(cfcontrols)
## estimate std.err contrast outcome
## Lapbelt - NONE NaN NaN Lapbelt - NONE death
## LapShoulderSeat - NONE NaN NaN LapShoulderSeat - NONE death
## Childseat - NONE NaN NaN Childseat - NONE death
# Binary-treatment causal forest on the replication sample,
# tuning all tunable forest parameters
cfbinary<- causal_forest(X=X_repl_train, Y=Y_repl_train, W=D_binary_repl_train,tune.parameters = "all")
average_treatment_effect(cfbinary)
## estimate std.err
## -0.113805346 0.008314345
# Refit on the selection sample; this overwrites `cfbinary`, and the
# selection-sample forest is the one used in all analyses below
cfbinary<- causal_forest(X=X_sel_train, Y=Y_sel_train, W=D_binary_sel_train,tune.parameters = "all")
average_treatment_effect(cfbinary)
## estimate std.err
## -0.061587470 0.008554263
# Inspect the tuning outcome (per the recorded output, defaults were kept)
cfbinary$tuning.output
## Tuning status: default.
## This indicates tuning was attempted. However, we could not find parameters that were expected to perform better than default:
##
## sample.fraction: 0.5
## mtry: 12
## min.node.size: 5
## honesty.fraction: 0.5
## honesty.prune.leaves: TRUE
## alpha: 0.05
## imbalance.penalty: 0
# Nuisance regression forests (Y on X, and binary D on X) used for the
# out-of-sample R^2 comparison against OLS further below
Y.regforest = regression_forest(X_sel_train, Y_sel_train)
D.regforest = regression_forest(X_sel_train, D_binary_sel_train)
Below I estimate the basic OLS models for the outcome (death) and the treatment (Dbinary).
# Fit OLS on the training sample: outcome regression (death on covariates)
# and treatment regression (Dbinary on covariates)
olsY<-lm(death~.,data=df_sel_train%>%select(-D,-Dbinary))
olsD<-lm(Dbinary~.,data=df_sel_train%>%select(-death,-D))
# Print
print("---- OLS for Y ----")
## [1] "---- OLS for Y ----"
summary(olsY)
##
## Call:
## lm(formula = death ~ ., data = df_sel_train %>% select(-D, -Dbinary))
##
## Residuals:
## Min 1Q Median 3Q Max
## -0.13763 -0.05059 -0.03876 -0.02876 1.00188
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 2.2327350 0.9015832 2.476 0.0133 *
## splmU55 -0.0453104 0.0093381 -4.852 1.26e-06 ***
## thoulbs_I -0.0008212 0.0018972 -0.433 0.6651
## numcrash -0.0007848 0.0006693 -1.173 0.2410
## weekend 0.0105586 0.0058113 1.817 0.0693 .
## lowviol 0.0028352 0.0062891 0.451 0.6521
## highviol 0.0445936 0.0106970 4.169 3.11e-05 ***
## ruralrd -0.0079654 0.0105910 -0.752 0.4520
## frimp -0.0023900 0.0060718 -0.394 0.6939
## suv -0.0187080 0.0099814 -1.874 0.0609 .
## modelyr -0.0007219 0.0005998 -1.204 0.2288
## age -0.0024101 0.0020236 -1.191 0.2337
## year -0.0003525 0.0007010 -0.503 0.6151
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.2033 on 5151 degrees of freedom
## Multiple R-squared: 0.01085, Adjusted R-squared: 0.008544
## F-statistic: 4.708 on 12 and 5151 DF, p-value: 1.085e-07
# Treatment-model summary (note the much higher R^2 than the outcome model)
print("---- OLS for D ----")
## [1] "---- OLS for D ----"
summary(olsD)
##
## Call:
## lm(formula = Dbinary ~ ., data = df_sel_train %>% select(-death,
## -D))
##
## Residuals:
## Min 1Q Median 3Q Max
## -1.0959 -0.2654 0.0865 0.2839 0.9651
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) -72.684428 1.778280 -40.873 < 2e-16 ***
## splmU55 0.063453 0.018418 3.445 0.000575 ***
## thoulbs_I 0.007469 0.003742 1.996 0.045984 *
## numcrash -0.001470 0.001320 -1.113 0.265608
## weekend -0.005567 0.011462 -0.486 0.627237
## lowviol 0.043783 0.012405 3.530 0.000420 ***
## highviol -0.044867 0.021099 -2.127 0.033508 *
## ruralrd -0.074257 0.020890 -3.555 0.000382 ***
## frimp -0.034271 0.011976 -2.862 0.004232 **
## suv 0.018379 0.019687 0.934 0.350589
## modelyr 0.014444 0.001183 12.210 < 2e-16 ***
## age -0.023478 0.003991 -5.882 4.3e-09 ***
## year 0.022411 0.001383 16.208 < 2e-16 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.401 on 5151 degrees of freedom
## Multiple R-squared: 0.3005, Adjusted R-squared: 0.2988
## F-statistic: 184.4 on 12 and 5151 DF, p-value: < 2.2e-16
Let us compare the out-of-sample performance of the OLS models and the regression forests.
# R-squared: 1 - RSS / TSS, where TSS is taken around the mean of `actual`.
# Can be negative when `preds` fits worse than the mean (useful out of sample).
r2 <- function(preds, actual) {
  rss <- sum((actual - preds)^2)
  tss <- sum((actual - mean(actual))^2)
  1 - rss / tss
}
# Out-of-sample R^2 on the held-out test set for each model and each target
r2_olsY<-r2(predict(olsY,newdata=df_sel_test),df_sel_test$death)
r2_olsD<-r2(predict(olsD,newdata=df_sel_test),df_sel_test$Dbinary)
r2_rfY<-r2(predict(Y.regforest,newdata=X_sel_test)$predictions,df_sel_test$death)
r2_rfD<-r2(predict(D.regforest,newdata=X_sel_test)$predictions,df_sel_test$Dbinary)
data.frame(Method=c("OLS","RF"),R2_D=c(r2_olsD,r2_rfD),R2_Y=c(r2_olsY,r2_rfY))
## Method R2_D R2_Y
## 1 OLS 0.2922920 0.005642891
## 2 RF 0.3231286 0.012266324
# Histogram of the forest's estimated propensity scores (W.hat)
# to eyeball overlap: mass near 0 or 1 would signal poor overlap
plotdata<-data.frame(what=cfbinary$W.hat)
ggplot(plotdata,aes(x=what))+geom_histogram(bins=100,fill="#f56c42",color="white")+xlim(0,1)+
theme_minimal()
### Diagnostic tests
# Omnibus calibration test: `mean.forest.prediction` near 1 means the average
# prediction is well calibrated; a positive, significant
# `differential.forest.prediction` indicates the forest captured real heterogeneity
test_calibration(cfbinary)
##
## Best linear fit using forest predictions (on held-out data)
## as well as the mean forest prediction as regressors, along
## with one-sided heteroskedasticity-robust (HC3) SEs:
##
## Estimate Std. Error t value Pr(>t)
## mean.forest.prediction 1.00739 0.13639 7.3863 8.758e-14 ***
## differential.forest.prediction 0.62310 0.32952 1.8909 0.02935 *
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
# Get variable importance from the forest and plot it as ranked bars
importance=variable_importance(cfbinary)
var_imp <- data.frame(importance=importance,names=colnames(X_sel_train))
ggplot(var_imp,aes(x= reorder(names,importance),y=importance))+
geom_bar(stat="identity",fill="#f56c42",color="white")+
theme_minimal()+
theme(axis.text.x = element_text(angle=45,vjust = 1, hjust=1))+
labs(x=" ")+
coord_flip()
### Characterising treatment effect heterogeneity
CATE distribution
# get out-of-bag CATE predictions on the training sample
cate<-data.frame(sample="CATEs",tau=predict(cfbinary)$predictions)
# histogram of all CATEs (y axis: share of observations per bin)
ggplot(cate,aes(x=tau))+
geom_histogram(aes(y=..count../sum(..count..)),bins=100,alpha=0.95, position = "identity",
fill="#f56c42",color="white")+
theme_minimal()+
labs(title=" ",x="Conditional Average Treatment Effect",y="Density")
Plot quartiles
# Split the sample into 4 groups (quartiles) based on the estimated CATEs
# (the old comment said "5 groups", contradicting n = 4 below)
df_sel_train["categroup"] <- factor(ntile(predict(cfbinary)$predictions, n=4))
# Calculate the doubly robust AIPW ATE within each CATE quartile
estimated_aipw_ate <- lapply(
  seq_len(4), function(w) {
    # return the estimate directly; the previous `ate <-` assignment was dead
    average_treatment_effect(cfbinary, subset = df_sel_train$categroup == w,method = "AIPW")
  })
# Combine the per-quartile results into one data frame
estimated_aipw_ate <- data.frame(do.call(rbind, estimated_aipw_ate))
estimated_aipw_ate$Ntile <- as.numeric(rownames(estimated_aipw_ate))
# Plot: AIPW ATE point estimates with 95% confidence intervals per quartile
ggplot(estimated_aipw_ate) +
  geom_pointrange(aes(x = Ntile, y = estimate, ymax = estimate + 1.96 * `std.err`, ymin = estimate - 1.96 * `std.err`),
                  size = 1,
                  position = position_dodge(width = .5)) +
  theme_minimal() +
  geom_hline(yintercept=0,linetype="dashed")+
  labs(x = "Quartile", y = "AIPW ATE", title = "AIPW ATEs by quartiles of the conditional average treatment effect")
# Balance table comparing the extreme quartiles. (The previous version
# assigned `sumstatdata` inside the `data =` argument; that hidden
# side-effect assignment is removed -- the object was never used again.)
datasummary_balance(~categroup,
                    data = df_sel_train%>%filter(categroup%in%c(1,4))%>%select(-D),
                    title = "Comparison of the first vs fourth quartile",
                    fmt= '%.3f',
                    dinm_statistic = "p.value")
| Mean | Std. Dev. | Mean | Std. Dev. | Mean | Std. Dev. | Mean | Std. Dev. | Diff. in Means | p | |
|---|---|---|---|---|---|---|---|---|---|---|
| splmU55 | 0.841 | 0.366 | 0.969 | 0.173 | 0.128 | 0.000 | ||||
| thoulbs_I | 2.129 | 1.071 | 3.468 | 1.280 | 1.339 | 0.000 | ||||
| numcrash | 6.760 | 4.897 | 6.239 | 2.299 | -0.521 | 0.001 | ||||
| weekend | 0.388 | 0.487 | 0.394 | 0.489 | 0.006 | 0.747 | ||||
| lowviol | 0.341 | 0.474 | 0.211 | 0.409 | -0.129 | 0.000 | ||||
| highviol | 0.125 | 0.331 | 0.016 | 0.127 | -0.108 | 0.000 | ||||
| ruralrd | 0.054 | 0.227 | 0.104 | 0.305 | 0.050 | 0.000 | ||||
| frimp | 0.559 | 0.497 | 0.833 | 0.373 | 0.273 | 0.000 | ||||
| suv | 0.030 | 0.171 | 0.156 | 0.363 | 0.126 | 0.000 | ||||
| death | 0.059 | 0.235 | 0.025 | 0.156 | -0.034 | 0.000 | ||||
| Dbinary | 0.700 | 0.458 | 0.583 | 0.493 | -0.117 | 0.000 | ||||
| modelyr | 1987.620 | 6.594 | 1986.474 | 9.902 | -1.146 | 0.001 | ||||
| age | 3.572 | 1.282 | 4.129 | 1.459 | 0.556 | 0.000 | ||||
| year | 1993.371 | 6.125 | 1992.638 | 8.486 | -0.733 | 0.012 |
Now by covariates
# Attach out-of-bag CATE predictions to the training data
df_sel_train["tau"]<-predict(cfbinary)$predictions
# mean CATE by model year x speed-limit indicator
df_sel_train_col<-df_sel_train%>%
group_by(modelyr,splmU55)%>%
summarise(tau=mean(tau))
p1<-ggplot(df_sel_train_col,aes(x=modelyr,y=tau,color=as.factor(splmU55)))+geom_point()+
ylim(-0.125,0)
# mean CATE by crash year x speed-limit indicator
df_sel_train_col<-df_sel_train%>%
group_by(year,splmU55)%>%
summarise(tau=mean(tau))
p2<-ggplot(df_sel_train_col,aes(x=year,y=tau,color=as.factor(splmU55)))+geom_point()+
ylim(-0.125,0)+labs(y="")
# mean CATE by vehicle weight (thoulbs_I is plotted x1000, so presumably
# measured in thousands of lbs -- confirm against the codebook)
df_sel_train_col<-df_sel_train%>%
group_by(thoulbs_I)%>%
summarise(tau=mean(tau))
p3<-ggplot(df_sel_train_col,aes(x=thoulbs_I*1000,y=tau))+geom_point()+
ylim(-0.125,0)+labs(y="")
# combine the three panels with a shared legend
ggarrange(p1, p2, p3, ncol=3, nrow=1, common.legend = TRUE, legend="bottom")
CATE distribution by speed limit
# get out-of-bag CATE predictions together with the speed-limit indicator
cate<-data.frame(sample="CATEs",splmU55=df_sel_train$splmU55,tau=predict(cfbinary)$predictions)
# overlaid histograms of CATEs for splmU55 = 0 vs 1
ggplot(cate,aes(x=tau,fill=as.factor(splmU55),group=splmU55))+
geom_histogram(aes(y=..count../sum(..count..)),bins=100,alpha=0.5, position = "identity",
color="white")+
theme_minimal()+
labs(title=" ",x="Conditional Average Treatment Effect",y="Density")
# ---- Multi-arm analysis: rebuild both datasets from scratch ----
# load full dataset
df_repl<-read_delim("../data/FARS-data-full-sample.txt",delim = "\t")%>%
filter(year<2004)%>%
select(-starts_with("imp"))
# load small dataset
df_sel<-read_delim("../data/FARS-data-selection-sample.txt",delim = "\t")%>%
filter(year<2004)%>%
select(-starts_with("imp"))
# remove rows with missing cases
df_repl<-df_repl[complete.cases(df_repl), ]
df_sel<-df_sel[complete.cases(df_sel), ]
# Treatment indicators: multi-valued D (restraint type) and binary Dbinary
df_repl<-df_repl%>%mutate(D=case_when(lapshould==1~"LapShoulderSeat",lapbelt==1~"Lapbelt",
childseat==1~"Childseat",TRUE~"NONE"),
D=factor(D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat")),
Dbinary=case_when(lapshould==1~1,lapbelt==1~1,childseat==1~1,TRUE~0))
df_sel <-df_sel %>%mutate(D=case_when(lapshould==1~"LapShoulderSeat",lapbelt==1~"Lapbelt",
childseat==1~"Childseat",TRUE~"NONE"),
D=factor(D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat")),
Dbinary=case_when(lapshould==1~1,lapbelt==1~1,childseat==1~1,TRUE~0))
# Convert categorical to indicators
df_repl<-dummy_cols(df_repl%>%select(-restraint))%>%select(-starts_with("D_"),-crashtm)
df_sel<-dummy_cols(df_sel%>%select(-restraint))%>%select(-starts_with("D_"),-crashtm)
# Variable selection intentionally skipped here (kept commented out);
# the X matrices below hand-pick the control columns instead
#df_repl<-df_repl%>%select(splmU55,thoulbs_I,numcrash,weekend,lowviol,highviol,ruralrd,frimp,suv,death,D,Dbinary)
#df_sel<- df_sel %>%select(splmU55,thoulbs_I,numcrash,weekend,lowviol,highviol,ruralrd,frimp,suv,death,D,Dbinary)
# Training and test data (same seed as the first split)
set.seed(seed)
df_repl_split <- initial_split(df_repl, prop = .5)
df_repl_train <- training(df_repl_split)
df_repl_test <- testing(df_repl_split)
df_sel_split <- initial_split(df_sel, prop = .5)
df_sel_train <- training(df_sel_split)
df_sel_test <- testing(df_sel_split)
# X Matrices (hand-picked controls; no modelyr/age/year in this specification)
X_repl_train<-as.matrix(df_repl_train%>%select(splmU55,thoulbs_I,numcrash,weekend,lowviol,highviol,ruralrd,frimp,suv))
X_repl_test<- as.matrix(df_repl_test%>%select(splmU55,thoulbs_I,numcrash,weekend,lowviol,highviol,ruralrd,frimp,suv))
X_sel_train<- as.matrix(df_sel_train%>%select(splmU55,thoulbs_I,numcrash,weekend,lowviol,highviol,ruralrd,frimp,suv))
X_sel_test<- as.matrix(df_sel_test%>%select(splmU55,thoulbs_I,numcrash,weekend,lowviol,highviol,ruralrd,frimp,suv))
# Intercept-only design matrices for the "no controls" specifications
X_repl_train_nocontrols<-as.matrix(rep(1,nrow(X_repl_train)))
X_repl_test_nocontrols<- as.matrix(rep(1,nrow(X_repl_test)))
X_sel_train_nocontrols<- as.matrix(rep(1,nrow(X_sel_train)))
X_sel_test_nocontrols<- as.matrix(rep(1,nrow(X_sel_test)))
# D matrices
# BUG FIX: the test-set treatment factors were previously built from the
# *training* data frames, giving the wrong rows (and the wrong length);
# they now use the corresponding test sets.
D_repl_train<-factor(df_repl_train$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_repl_test<-factor(df_repl_test$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_sel_train<-factor(df_sel_train$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_sel_test<-factor(df_sel_test$D,levels=c("NONE","Lapbelt","LapShoulderSeat","Childseat"))
D_binary_repl_train<-as.matrix(df_repl_train%>%select(Dbinary))
D_binary_repl_test<- as.matrix(df_repl_test%>%select(Dbinary))
D_binary_sel_train<- as.matrix(df_sel_train%>%select(Dbinary))
D_binary_sel_test<- as.matrix(df_sel_test%>%select(Dbinary))
# Y matrices
Y_repl_train<-as.matrix(df_repl_train%>%select(death))
Y_repl_test<- as.matrix(df_repl_test%>%select(death))
Y_sel_train<- as.matrix(df_sel_train%>%select(death))
Y_sel_test<- as.matrix(df_sel_test%>%select(death))
# Multi-arm causal forest on the selection sample with the hand-picked controls
cfmulti <- multi_arm_causal_forest(X=X_sel_train, Y=Y_sel_train, W=D_sel_train)
average_treatment_effect(cfmulti)
## estimate std.err contrast outcome
## Lapbelt - NONE -0.05571726 0.007624041 Lapbelt - NONE death
## LapShoulderSeat - NONE -0.05396493 0.006828650 LapShoulderSeat - NONE death
## Childseat - NONE -0.05746134 0.005700159 Childseat - NONE death
# Get variable importance for the multi-arm forest and plot as ranked bars
importance=variable_importance(cfmulti)
var_imp <- data.frame(importance=importance,names=colnames(X_sel_train))
ggplot(var_imp,aes(x= reorder(names,importance),y=importance))+
geom_bar(stat="identity",fill="#f56c42",color="white")+
theme_minimal()+
theme(axis.text.x = element_text(angle=45,vjust = 1, hjust=1))+
labs(x=" ")+
coord_flip()
### Characterising treatment effect heterogeneity
CATE distribution
Now by covariates
## policy_tree object
## Tree depth: 2
## Actions: 1: control 2: treated
## Variable splits:
## (1) split_variable: thoulbs_I split_value: 3.128
## (2) split_variable: thoulbs_I split_value: 3.101
## (4) * action: 1
## (5) * action: 2
## (3) split_variable: numcrash split_value: 17
## (6) * action: 1
## (7) * action: 2
## policy_tree object
## Tree depth: 2
## Actions: 1: NONE 2: Lapbelt 3: LapShoulderSeat 4: Childseat
## Variable splits:
## (1) split_variable: thoulbs_I split_value: 2.364
## (2) split_variable: thoulbs_I split_value: 2.362
## (4) * action: 1
## (5) * action: 2
## (3) split_variable: thoulbs_I split_value: 4.34
## (6) * action: 1
## (7) * action: 3